import copy
import json
import os

import numpy as np
import torch

from ModularUtils.Experiment_Class import Experiment
from ModularUtils.FunctionsConstant import getdoKey, plot_lines
from ModularUtils.Functions_Plot_Results import plot_saved_results
from Train_By_Components.Causal_TrainGraph import set_trainGraph

# Experiment handle used below for its `label_names`; `set_trainGraph` is
# passed to the constructor as-is.
Exp = Experiment(
    "Exp1",
    set_trainGraph,
    new_experiment=False,
    features=["feature"],
    Data_intervs=[{}],
)

# Groups of variable names; each group is later turned into one query key
# (and therefore one plotted curve). The last group is the experiment's own
# label list.
all_mechs = [
    ["W0"],
    ["W1"],
    ["X0", "W0", "W1", "Y0"],
    ["X1", "X2", "W1", "Y1"],
    Exp.label_names,
]

features = ["feature"]

# Use the GPU when torch reports that CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Root directories of the two runs being compared. Each is expected to hold
# "tvd/" and "kl/" sub-directories (see the loading code further down).
# NOTE(review): these look like placeholder values -- point them at the real
# result directories before running.
last_exp_sync = "sync_exp_path"
last_exp_async = "async_exp_path"

mech_paths = [last_exp_sync, last_exp_async]

# One query key per variable group, in the same order as `all_mechs`.
mech_queries = [getdoKey(mech, {}) for mech in all_mechs]

# Per-query accumulators for the two metrics. Every key gets its own fresh
# list so that the metrics never share storage.
tvd_diff = {query_key: [] for query_key in mech_queries}
kl_diff = {query_key: [] for query_key in mech_queries}

# Extra slot for the synchronous run: the key of the full label set,
# prefixed with "sync" so it does not collide with the un-prefixed entry.
for metric_store in (tvd_diff, kl_diff):
    metric_store["sync" + getdoKey(Exp.label_names, {})] = []

def _load_series(file_path):
    """Load the tensor saved at *file_path* and return it as a Python list.

    `map_location="cpu"` lets a tensor that was saved on a GPU be read on a
    machine without CUDA; the value is moved to the CPU right after anyway.
    """
    loaded = torch.load(file_path, map_location="cpu")
    return loaded.detach().cpu().numpy().tolist()


def _tail(series, count):
    """Return the last *count* items of *series*.

    Unlike `series[-count:]`, this yields an empty list when *count* is 0
    (`series[-0:]` would return the whole list).
    """
    return series[max(len(series) - count, 0):]


# Length of the shortest series that was actually loaded. It stays None until
# a result file is found, so it is always either None or an int and can be
# used safely as a slice bound (a float sentinel such as 1e6 cannot).
min_val = None

# Synchronous run: only the full-label-set distribution is read, and it is
# stored under the "sync"-prefixed key.
path = last_exp_sync
dist = query = getdoKey(Exp.label_names, {})
if os.path.exists(path + "/tvd/" + dist):
    tvds = _load_series(path + "/tvd/" + dist)
    tvd_diff["sync" + dist] += tvds
    kls = _load_series(path + "/kl/" + dist)
    kl_diff["sync" + dist] += kls
    min_val = len(tvds)

# Asynchronous run: try every key; keys without a saved file are skipped and
# keep their empty list.
path = last_exp_async
for dist in tvd_diff:
    if not os.path.exists(path + "/tvd/" + dist):
        continue
    tvds = _load_series(path + "/tvd/" + dist)
    tvd_diff[dist] += tvds
    kls = _load_series(path + "/kl/" + dist)
    kl_diff[dist] += kls
    min_val = len(tvds) if min_val is None else min(min_val, len(tvds))

# Keep only the trailing `min_val` entries of every series so all curves have
# a common length. Nothing was loaded when min_val is None, so every list is
# empty and there is nothing to trim.
if min_val is not None:
    for dist in tvd_diff:
        tvd_diff[dist] = _tail(tvd_diff[dist], min_val)
        kl_diff[dist] = _tail(kl_diff[dist], min_val)


# Human-readable legend entries. They are matched to the curves purely by
# position, so their order must follow the insertion order of `tvd_diff` /
# `kl_diff`: the per-group keys first, the "sync" entry last.
# (`tvd_diff.keys()` could be used instead to label curves with the raw keys.)
label_keys = [
    "P(W0|X0)",
    "P(W1|X0,X1,X2)",
    "P(X0,W0,W1,Y0)",
    "P(X1,X2,W1,Y1)",
    "Async_P(V)",
    "Sync_P(V)",
]

# One figure per metric, with the same title and the same legend.
for y_axis_label, metric_store in (
    ("Total Variation Distance", tvd_diff),
    ("KL Divergence", kl_diff),
):
    plot_lines(
        "WhatIf-GAN Modular Training",
        y_axis_label,
        list(metric_store.values()),
        list(label_keys),
        save_plot=False,
        path="",
    )





